{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "```\n",
        "# Copyright 2022 Google Inc.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "```"
      ],
      "metadata": {
        "id": "hCuI3G1qcZJ8"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This colab accompanies UniProt release 2022_04, for which Google predicted\n",
        "protein names for 88% of all uncharacterized proteins (over 1 in 5 proteins in UniProt).\n",
        "\n",
        "This colab allows you to run a model that's very similar to the one used in the UniProt release. **Put in the amino acid sequence below**, and select \"Runtime > Run all\" from the menu bar above to **get name predictions for your protein**!\n",
        "\n",
        "This colab takes a few minutes to run the first time; after that, you get protein name predictions in a few seconds!"
      ],
      "metadata": {
        "id": "swj2Jhw8cvAR"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 1. Import code"
      ],
      "metadata": {
        "id": "Q8TtkEsX7Z2q"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "wSfRFuO5tive",
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "69a8e10a-3a23-40bb-ee1c-9a26406c9ab1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[K     |████████████████████████████████| 4.9 MB 10.8 MB/s \n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "#@markdown Please execute this cell by pressing the _Play_ button \n",
        "#@markdown on the left to import the dependencies. It can take a few minutes.\n",
        "!python3 -m pip install -q -U tensorflow==2.8.2\n",
        "!python3 -m pip install -q -U tensorflow-text==2.8.2\n",
        "import tensorflow as tf\n",
        "import tensorflow_text\n",
        "import numpy as np\n",
        "import re\n",
        "\n",
        "import IPython.display\n",
        "from absl import logging\n",
        "\n",
        "logging.set_verbosity(logging.ERROR)  # Turn down tensorflow warnings\n",
        "\n",
        "def print_markdown(string):\n",
        "  \"\"\"Renders `string` as Markdown in the output area of the current cell.\"\"\"\n",
        "  rendered = IPython.display.Markdown(string)\n",
        "  IPython.display.display(rendered)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 2. Load the model"
      ],
      "metadata": {
        "id": "g7Vxy09g6WpD"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "uIZGEiYMXqUD",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@markdown Please execute this cell by pressing the _Play_ button.\n",
        "\n",
        "def query(seq):\n",
        "  \"\"\"Wraps `seq` in the text prompt template that is fed to the model.\"\"\"\n",
        "  template = \"[protein_name_in_english] <extra_id_0> [sequence] {}\"\n",
        "  return template.format(seq)\n",
        "\n",
        "# Matches names that start with an EC number such as 4.1.1.48, 3.4.21.- or 1.1.1.n1.\n",
        "# The dots are escaped so that only a literal '.' is accepted as a separator;\n",
        "# an unescaped '.' would match any character (e.g. '1234567 protein').\n",
        "EC_NUMBER_REGEX = r'(\\d+)\\.([\\d\\-n]+)\\.([\\d\\-n]+)\\.([\\d\\-n]+)'\n",
        "\n",
        "def run_inference(seq):\n",
        "  \"\"\"Queries the loaded model for one sequence and post-processes its outputs.\n",
        "\n",
        "  Args:\n",
        "    seq: Amino acid sequence string; it is inserted into the prompt as-is.\n",
        "\n",
        "  Returns:\n",
        "    A `(names, scores)` tuple of lists. Both lists hold the model's\n",
        "    `output_0` / `output_1` entries in reversed order. Names are decoded\n",
        "    with the '<extra_id_0> ' marker removed and get an 'EC:' prefix when\n",
        "    they start with an EC-number pattern; scores are passed through `np.exp`.\n",
        "  \"\"\"\n",
        "  outputs = infer(tf.constant([query(seq)]))\n",
        "  raw_names = outputs['output_0'][0].numpy().tolist()\n",
        "  raw_scores = outputs['output_1'][0].numpy().tolist()\n",
        "  beam_size = len(raw_names)\n",
        "\n",
        "  # NOTE(review): presumably the model returns its beams worst-first, which\n",
        "  # is why both lists are walked back to front -- confirm against the export.\n",
        "  names = []\n",
        "  for raw_name in reversed(raw_names):\n",
        "    name = raw_name.decode().replace('<extra_id_0> ', '')\n",
        "    if re.match(EC_NUMBER_REGEX, name):\n",
        "      name = 'EC:' + name\n",
        "    names.append(name)\n",
        "\n",
        "  scores = [np.exp(raw_scores[idx]) for idx in range(beam_size - 1, -1, -1)]\n",
        "  return names, scores"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown Please execute this cell by pressing the _Play_ button \n",
        "#@markdown on the left to load the model. It can take a few minutes.\n",
        "\n",
        "# Public bucket folder holding the exported SavedModel, and the local folder it is copied to.\n",
        "MODEL_URL = \"https://storage.googleapis.com/brain-genomics-public/research/proteins/protnlm/uniprot_2022_04/savedmodel__20221011__030822_1128_bs1.bm10.eos_cpu\"\n",
        "MODEL_DIR = \"protnlm\"\n",
        "\n",
        "# `mkdir -p` creates MODEL_DIR and its variables/ subfolder in one call;\n",
        "# `wget -nc` skips any file that is already present locally.\n",
        "! mkdir -p {MODEL_DIR}/variables\n",
        "! wget -nc {MODEL_URL}/saved_model.pb -P {MODEL_DIR} -q\n",
        "! wget -nc {MODEL_URL}/variables/variables.index -P {MODEL_DIR}/variables/ -q\n",
        "! wget -nc {MODEL_URL}/variables/variables.data-00000-of-00001 -P {MODEL_DIR}/variables/ -q\n",
        "\n",
        "# `infer` is the callable used by run_inference.\n",
        "imported = tf.saved_model.load(export_dir=MODEL_DIR)\n",
        "infer = imported.signatures[\"serving_default\"]"
      ],
      "metadata": {
        "cellView": "form",
        "id": "N9YtieGL6YRk"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "QUDDawYq4e3B",
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 257
        },
        "outputId": "c984158d-1558-4474-a32c-3df9a89e34ab"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 1: **Hemoglobin subunit alpha** with a score of **0.363**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 2: **Alpha-globin** with a score of **0.161**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 3: **Hemoglobin subunit alpha-like** with a score of **0.070**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 4: **Hemoglobin alpha chain** with a score of **0.067**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 5: **Hemoglobin alpha subunit 2** with a score of **0.034**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 6: **Hemoglobin subunit alpha-1** with a score of **0.034**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 7: **GLOBIN domain-containing protein** with a score of **0.027**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 8: **Hemoglobin subunit alpha-2** with a score of **0.025**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 9: **Hemoglobin alpha-1 chain** with a score of **0.018**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 10: **Hemoglobin alpha-2 chain** with a score of **0.012**"
          },
          "metadata": {}
        }
      ],
      "source": [
        "#@title 3. Put your sequence here (hemoglobin is pre-loaded)\n",
        "\n",
        "#@markdown Press the _Play_ button to get a prediction.\n",
        "#@markdown The first time can take a few minutes.\n",
        "#@markdown \n",
        "#@markdown Subsequent predictions take a few seconds.\n",
        "sequence = \"MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG KKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTP AVHASLDKFLASVSTVLTSKYR\" #@param {type:\"string\"}\n",
        "# Drop every whitespace character (spaces, tabs, newlines) that a pasted\n",
        "# sequence may contain, not only plain spaces.\n",
        "sequence = ''.join(sequence.split())\n",
        "\n",
        "names, scores = run_inference(sequence)\n",
        "\n",
        "# Show the predictions in the order returned by run_inference, numbered from 1.\n",
        "for rank, (name, score) in enumerate(zip(names, scores), start=1):\n",
        "  print_markdown(f\"Prediction number {rank}: **{name}** with a score of **{score:.03f}**\")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 4. A second example prediction (bifunctional TrpCF protein P22098)\n",
        "\n",
        "#@markdown Press the _Play_ button to get a prediction.\n",
        "sequence = \"MKMTDFNTQQANNLSEHVSKKEAEMAEVLAKIVRDKYQWVAERKASQHLSTFQSDLLPSD RSFYDALSGDKTVFITECKKASPSKGLIRNDFDLDYIASVYNNYADAISVLTDEKYFQGS FDFLPQVRRQVKQPVLCKDFMVDTYQVYLARHYGADAVLLMLSVLNDEEYKALEEAAHSL NMGILTEVSNEEELHRAVQLGARVIGINNRNLRDLTTDLNRTKALAPTIRKLAPNATVIS ESGIYTHQQVRDLAEYADGFLIGSSLMAEDNLELAVRKVTLGENKVCGLTHPDDAAKAYQ AGAVFGGLIFVEKSKRAVDFESARLTMSGAPLNYVGVFQNHDVDYVASIVTSLGLKAVQL HGLEDQEYVNQLKTELPVGVEIWKAYGVADTKPSLLADNIDRHLLDAQVGTQTGGTGHVF DWSLIGDPSQIMLAGGLSPENAQQAAKLGCLGLDLNSGVESAPGKKDSQKLQAAFHAIRN Y\" #@param {type:\"string\"}\n",
        "# Drop every whitespace character (spaces, tabs, newlines) that a pasted\n",
        "# sequence may contain, not only plain spaces.\n",
        "sequence = ''.join(sequence.split())\n",
        "\n",
        "names, scores = run_inference(sequence)\n",
        "\n",
        "# Show the predictions in the order returned by run_inference, numbered from 1.\n",
        "for rank, (name, score) in enumerate(zip(names, scores), start=1):\n",
        "  print_markdown(f\"Prediction number {rank}: **{name}** with a score of **{score:.03f}**\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 257
        },
        "cellView": "form",
        "id": "MKUsa5LB4ZyF",
        "outputId": "c22c8236-ecf7-4edc-c989-eee4af526f94"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 1: **Multifunctional fusion protein** with a score of **0.180**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 2: **Indole-3-glycerol phosphate synthase** with a score of **0.167**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 3: **IGPS** with a score of **0.163**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 4: **EC:4.1.1.48** with a score of **0.159**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 5: **N-(5'-phosphoribosyl)anthranilate isomerase** with a score of **0.109**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 6: **PRAI** with a score of **0.105**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 7: **EC:5.3.1.24** with a score of **0.100**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 8: **Bifunctional indole-3-glycerol-phosphate synthase TrpC/phosphoribosylanthranilate isomerase TrpF** with a score of **0.002**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 9: **Tryptophan biosynthesis protein TrpCF** with a score of **0.001**"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "Prediction number 10: **Anthranilate synthase component 1** with a score of **0.001**"
          },
          "metadata": {}
        }
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "gpuClass": "standard",
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}